{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d98dadd5-7d62-466e-9e60-cfd8b955e510",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Taichi] version 1.4.1, llvm 15.0.4, commit e67c674e, linux, python 3.9.16\n",
      "[I 03/29/23 20:19:03.819 2437685] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import taichi as ti\n",
    "\n",
    "from opt import get_opts\n",
    "from train import NeRFSystem\n",
    "from datasets.ray_utils import get_rays\n",
    "from modules.intersection import RayAABBIntersector as ray_intersect\n",
    "\n",
    "def taichi_init(args):\n",
    "    \"\"\"Start the Taichi runtime on the CUDA backend.\n",
    "\n",
    "    Always passes `arch=ti.cuda` and `device_memory_GB=8.0` to `ti.init`;\n",
    "    `half2_vectorization=True` is added only when `args.half2_opt` is truthy.\n",
    "\n",
    "    Args:\n",
    "        args: parsed options object exposing a `half2_opt` attribute.\n",
    "    \"\"\"\n",
    "    init_kwargs = dict(arch=ti.cuda, device_memory_GB=8.0)\n",
    "    if args.half2_opt:\n",
    "        init_kwargs.update(half2_vectorization=True)\n",
    "    ti.init(**init_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc1b6205-5bde-49a8-bd8c-66e355e0c20e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Taichi] Starting on arch=cuda\n",
      "GridEncoding: Nmin=16 b=1.31951 F=2 T=2^19 L=16\n",
      "per_level_scale:  1.3195079107728942\n",
      "offset_:  5710032\n",
      "total_hash_size:  11420064\n",
      "Loading 100 train images ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:01, 54.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 200 test images ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "200it [00:03, 55.70it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([Adam (\n",
       "  Parameter Group 0\n",
       "      amsgrad: False\n",
       "      betas: (0.9, 0.999)\n",
       "      capturable: False\n",
       "      differentiable: False\n",
       "      eps: 1e-15\n",
       "      foreach: None\n",
       "      fused: None\n",
       "      initial_lr: 0.01\n",
       "      lr: 0.01\n",
       "      maximize: False\n",
       "      weight_decay: 0\n",
       "  )],\n",
       " [<torch.optim.lr_scheduler.CosineAnnealingLR at 0x7f26f00c0250>])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset location. Defaults to the path used when this notebook was recorded;\n",
    "# set the NGP_ROOT_DIR environment variable before starting the kernel to override it.\n",
    "ROOT_DIR = os.environ.get(\n",
    "    \"NGP_ROOT_DIR\",\n",
    "    \"/home/loyot/workspace/code/ngp_pl_gui/Synthetic_NeRF/Lego\",\n",
    ")\n",
    "\n",
    "# Command-line style options handed to the project's option parser\n",
    "prefix_args = [\n",
    "    \"--root_dir\", ROOT_DIR,\n",
    "    \"--exp_name\", \"Lego\", \"--perf\",\n",
    "    \"--num_epochs\", \"20\",\n",
    "    \"--batch_size\", \"8192\",\n",
    "    \"--lr\", \"1e-2\", \"--no_save_test\",\n",
    "]\n",
    "hparams = get_opts(prefix_args)\n",
    "\n",
    "# Taichi is initialised before the system (and its model) is constructed\n",
    "device = torch.device('cuda')\n",
    "taichi_init(hparams)\n",
    "system = NeRFSystem(hparams).to(device)\n",
    "system.setup(\"train\")\n",
    "system.configure_optimizers()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31af685c-f50b-4b92-8c1a-121842e4cff5",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ecb17726-99f1-46c6-937b-3663c18abac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handles to the pieces of the training system used in the manual step below\n",
    "model = system.model  # Taichi NGP\n",
    "optimizer = system.net_opt\n",
    "dataset = system.train_dataset\n",
    "\n",
    "# Move the dataset's poses and directions to `device` once, so that indexing\n",
    "# them later yields tensors that already live there\n",
    "dataset.directions = dataset.directions.to(device)\n",
    "dataset.poses = dataset.poses.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9df384a7-0442-4b42-9a61-ccc17bb01918",
   "metadata": {},
   "source": [
    "## Train Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f50802c0-a539-4063-be66-3fe70ac9a521",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get data from dataset: one batch carrying image indices, pixel indices and target colours\n",
    "batch = dataset[0]\n",
    "img_idxs, pix_idxs = batch['img_idxs'], batch['pix_idxs']\n",
    "rgb_gt = batch['rgb'].to(device)\n",
    "\n",
    "# Gather the pose and the direction belonging to every entry of the batch\n",
    "poses = dataset.poses[img_idxs]\n",
    "directions = dataset.directions[pix_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "62553481-5afd-4bea-90c5-d86ed4896186",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update occupancy grid\n",
    "# Called with its default arguments. The ray-marching cell further down reads\n",
    "# `model.density_bitfield` (presumably refreshed by this call; TODO confirm).\n",
    "model.update_density_grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61b4136f-372d-40c8-abb5-6df74ce64f02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "occ_grid shape:  torch.Size([1, 2097152])\n",
      "128 x 128 x 128:  2097152\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the second dimension of the density grid should equal 128^3 cells\n",
    "print(\"occ_grid shape: \", model.density_grid.shape)\n",
    "print(\"128 x 128 x 128: \", 128 **3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c512133-414d-488b-a200-a9de0ec8f87c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate rays\n",
    "# `get_rays` comes from datasets.ray_utils; it combines the gathered directions and poses\n",
    "# into per-ray origins (`rays_o`) and directions (`rays_d`) (naming-based; confirm in ray_utils).\n",
    "rays_o, rays_d = get_rays(directions, poses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ff297ae0-26b3-4ee0-b7bf-a5f45770ee9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ray shape:  torch.Size([8192, 3])\n"
     ]
    }
   ],
   "source": [
    "# One row per ray; the leading dimension matches the --batch_size option (8192)\n",
    "print(\"ray shape: \", rays_o.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63b63386-2783-41e1-822a-1261a4b97d9d",
   "metadata": {},
   "source": [
    "### Render"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "115391f2-91eb-4eda-8c85-ea9bd683ca09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ray intersection with box\n",
    "NEAR_DISTANCE = 0.01\n",
    "\n",
    "# The intersector is handed contiguous buffers\n",
    "rays_o, rays_d = rays_o.contiguous(), rays_d.contiguous()\n",
    "box_center, box_size = model.center, model.half_size\n",
    "\n",
    "# NOTE(review): the trailing 1 is presumably the maximum number of hits per ray; confirm\n",
    "_, hits_t, _ = ray_intersect.apply(rays_o, rays_d, box_center, box_size, 1)\n",
    "\n",
    "# Near values that fall in [0, NEAR_DISTANCE) are raised to NEAR_DISTANCE\n",
    "near_t = hits_t[:, 0, 0]\n",
    "hits_mask = (near_t >= 0) & (near_t < NEAR_DISTANCE)\n",
    "hits_t[hits_mask, 0, 0] = NEAR_DISTANCE\n",
    "hits_t = hits_t[:, 0]  # keep only index 0 of the second axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "690b3075-4baf-4ff9-8a84-9930d6c6c2ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ray marching\n",
    "MAX_SAMPLES = 1024\n",
    "\n",
    "# Unpack the six outputs directly; only some of them are used by the cells below\n",
    "rays_a, xyzs, dirs, deltas, ts, n_samples = model.ray_marching(\n",
    "    rays_o,\n",
    "    rays_d,\n",
    "    hits_t,                  # near & far\n",
    "    model.density_bitfield,  # occupancy grid\n",
    "    model.cascades,\n",
    "    model.scale,\n",
    "    0,                       # NOTE(review): bare literal, meaning not visible here; confirm against ray_marching\n",
    "    model.grid_size,\n",
    "    MAX_SAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4d81d7f2-0a46-4ca9-a78c-f44e5e7b6e3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xyzs shape:  torch.Size([2055705, 3])\n"
     ]
    }
   ],
   "source": [
    "# Ray level -> Sample level\n",
    "# After marching, the leading dimension counts samples over all rays rather than rays\n",
    "print(\"xyzs shape: \", xyzs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c1068ad0-e340-4dcb-9992-070e56a0458a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rays_a shape:  torch.Size([8192, 3])\n"
     ]
    }
   ],
   "source": [
    "# `rays_a` has one row of three integers per ray (an example row is shown in the next cell)\n",
    "print(\"rays_a shape: \", rays_a.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b001cc60-fb75-466e-9f3e-4f5fad2ed82d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7849, 9796,  253], device='cuda:0', dtype=torch.int32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep the rows whose third column is positive (presumably the per-ray sample count; TODO confirm)\n",
    "# and display row 40 of that subset as an arbitrary example\n",
    "rays_with_samples = rays_a[rays_a[:, 2] > 0]\n",
    "rays_with_samples[40]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "225d2b09-88a3-409e-93a9-6db74fc06610",
   "metadata": {},
   "source": [
    "#### Ray marching pseudo-code\n",
    "\n",
    "Illustrative only (not executable): the helper names below are placeholders, not notebook variables.\n",
    "\n",
    "```python\n",
    "# For each ray\n",
    "for r in ray_ids:\n",
    "    n = 0                                 # samples kept for this ray\n",
    "    t1, t2 = hits_t[r, 0], hits_t[r, 1]   # near & far\n",
    "    t = t1\n",
    "    while 0 <= t < t2 and n < MAX_SAMPLES:\n",
    "        xyz = rays_o[r] + t * rays_d[r]\n",
    "        nxyz = ceil(xyz)                  # grid index of the sample position\n",
    "        occ = occupancy_grid[nxyz]\n",
    "\n",
    "        if occ:\n",
    "            val_samples[r, n] = xyz       # save sample\n",
    "            n += 1                        # add sample\n",
    "            t += calculate_dt(t)          # next step\n",
    "        else:\n",
    "            t += advance_to_next_cell()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f2bd7576-0f62-4049-b5df-e5a9053246f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run radiance field\n",
    "# Query the model at every sample position together with its direction;\n",
    "# no extra keyword arguments are passed\n",
    "sigmas, rgbs = model(xyzs, dirs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ee7e41c5-9537-45e7-93b0-35218e33e47b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Volume rendering\n",
    "T_threshold = 1e-4\n",
    "\n",
    "# Of the five outputs only the fourth (`rgb`, one colour per ray) is kept\n",
    "_, _, _, rgb, _ = model.render_func(sigmas, rgbs, deltas, ts, rays_a, T_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2140746a-7b34-4a29-82cd-0e84407645be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rgb shape:  torch.Size([8192, 3])\n"
     ]
    }
   ],
   "source": [
    "# Sample level -> Ray level\n",
    "# Rendering collapses the samples again: `rgb` has one row per ray, like `rays_o`\n",
    "print(\"rgb shape: \", rgb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9cada47b-1eea-43aa-a3ae-878cd1b2f52f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Loss\n",
    "loss = ((rgb - rgb_gt)**2).mean()  # mean squared error over all rays and channels\n",
    "\n",
    "# PyTorch accumulates into .grad, so clear stale gradients first; otherwise\n",
    "# repeating the train step would add the new gradients to the previous ones.\n",
    "optimizer.zero_grad()\n",
    "loss.backward()\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b80a7310-25f1-4d0d-8831-667e51512bf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:  tensor(0.5511, device='cuda:0', grad_fn=<MeanBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# `loss` was computed before optimizer.step(), so this is the pre-update value\n",
    "print(\"loss: \", loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "158afedb-1b98-43f4-ade4-af8d8f411d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release unoccupied cached GPU memory held by PyTorch's caching allocator\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dd5cd9b-8650-4d03-8fd6-f227c08aa9b2",
   "metadata": {},
   "source": [
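    "# Model\n",
    "\n",
    "Rebuild the radiance-field forward pass by hand from its building blocks: a hash encoder, a direction encoder and two MLPs."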
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "39c148ea-ca04-4d6d-a5cf-30041328cfb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from modules.networks import MLP\n",
    "from modules.hash_encoder import HashEncoder\n",
    "from modules.spherical_harmonics import DirEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "64490268-b59a-4d51-802d-e1ae74e3dfcf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "per_level_scale:  1.3195079107728942\n",
      "offset_:  5710032\n",
      "total_hash_size:  11420064\n"
     ]
    }
   ],
   "source": [
    "# Get box size\n",
    "box_min, box_max = model.xyz_min, model.xyz_max\n",
    "\n",
    "# Encoders\n",
    "N_LEVELS = 16\n",
    "MIN_RES, MAX_RES = 16, 1024\n",
    "# Geometric growth factor: MIN_RES * per_level_scale**(N_LEVELS - 1) == MAX_RES\n",
    "per_level_scale = np.exp(np.log(MAX_RES / MIN_RES) / (N_LEVELS - 1))\n",
    "\n",
    "# or Triplane; the batch size is passed for pre-allocation\n",
    "pos_encoder = HashEncoder(per_level_scale, hparams.batch_size).to(device)\n",
    "dir_encoder = DirEncoder(hparams.batch_size).to(device)\n",
    "\n",
    "# MLP\n",
    "# Both networks share input width, hidden width and the no-bias setting\n",
    "mlp_common = dict(input_dim=32, net_width=64, bias_enabled=False)\n",
    "\n",
    "sigma_net = MLP(output_dim=16, net_depth=1, **mlp_common).to(device)\n",
    "\n",
    "rgb_net = MLP(\n",
    "    output_dim=3,\n",
    "    net_depth=2,\n",
    "    output_activation=torch.nn.Sigmoid(),\n",
    "    **mlp_common,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "61be2869-4b60-4277-af20-d1f7daa1f7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model forward\n",
    "class TruncExp(torch.autograd.Function):\n",
    "    \"\"\"Exponential whose backward pass clamps the input to [-15, 15].\n",
    "\n",
    "    The forward value is the plain `exp(x)`; only the gradient is truncated.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        # Keep the raw input; the backward pass re-evaluates exp on a clamped copy\n",
    "        ctx.save_for_backward(x)\n",
    "        return x.exp()\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dL_dout):\n",
    "        (saved_x,) = ctx.saved_tensors\n",
    "        clamped = saved_x.clamp(-15, 15)\n",
    "        return dL_dout * clamped.exp()\n",
    "\n",
    "# Get sigma\n",
    "# NOTE: this cell rebinds `sigmas` and `rgbs`, which earlier held the outputs of model(xyzs, dirs)\n",
    "xyzs_unit = (xyzs - box_min) / (box_max - box_min)  # rescale positions relative to the box\n",
    "pos_embedding = pos_encoder(xyzs_unit)\n",
    "sigma_feat = sigma_net(pos_embedding)\n",
    "sigmas = TruncExp.apply(sigma_feat[:, 0])  # only the first output channel feeds the density\n",
    "\n",
    "# Get rgb\n",
    "unit_dirs = dirs / torch.norm(dirs, dim=1, keepdim=True)\n",
    "dir_embedding = dir_encoder((unit_dirs + 1) / 2)  # [-1, 1] -> [0, 1]\n",
    "rgbs = rgb_net(torch.cat([dir_embedding, sigma_feat], 1))\n",
    "\n",
    "# Wrapped in a function, this cell would end with `return sigmas, rgbs`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "37c3a4b4-9cd7-42aa-970e-9ce9501b75a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sigmas shape:  torch.Size([2074017])\n",
      "rgbs shape:  torch.Size([2074017, 3])\n"
     ]
    }
   ],
   "source": [
    "# One sigma value and one 3-channel colour per sample; the leading dimension should match xyzs.shape[0]\n",
    "print(\"sigmas shape: \", sigmas.shape)\n",
    "print(\"rgbs shape: \", rgbs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5071da9-b167-4c94-b4a0-2dd997b644fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_nightly]",
   "language": "python",
   "name": "conda-env-torch_nightly-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
